'''
Author: SlytherinGe
LastEditTime: 2021-11-03 16:54:15
'''
import os
import cv2
import numpy as np
import xml.etree.ElementTree as ET

class SSDDRboxReader(object):
    """Read rotated-box (RBox) annotations stored as one XML file per image.

    Every entry of ``anno_root`` (as returned by ``os.listdir``) is treated
    as an annotation file and entries are indexed in sorted filename order.
    Non-XML entries in that directory will make ``ET.parse`` fail when they
    are read.
    """

    # Corner coordinate tags read from each object's <bndbox>; the order here
    # is also the key order of the returned 'bndbox' dict.
    _CORNER_KEYS = ('x1', 'x2', 'x3', 'x4', 'y1', 'y2', 'y3', 'y4')

    def __init__(self, anno_root) -> None:
        """
        Args:
            anno_root: directory containing the annotation files.
        """
        super().__init__()
        self.anno_root = anno_root
        self.anno_files = sorted(os.listdir(anno_root))
        self.dataset_size = len(self.anno_files)

    def _get_anno_index(self, idx):
        """Parse the idx-th annotation file into a nested dict of raw strings.

        Args:
            idx: position in the sorted file list. Negative values index from
                the end, as with a Python list.

        Returns:
            dict with keys 'filename', 'folder', 'size' (a dict with 'width',
            'height', 'depth') and 'object' (a list of dicts with 'name',
            'truncated', 'difficult' and 'bndbox' holding x1..x4 / y1..y4).
            All leaf values are the XML element text, i.e. strings.

        Raises:
            IndexError: if idx is outside the dataset.
        """
        # Explicit check rather than `assert`, which is removed under -O.
        if not -self.dataset_size <= idx < self.dataset_size:
            raise IndexError(
                f'annotation index {idx} out of range '
                f'(dataset size {self.dataset_size})')
        anno_file = self.anno_files[idx]
        tree = ET.parse(os.path.join(self.anno_root, anno_file))

        size = tree.find('size')
        anno_dict = {
            'filename': tree.find('filename').text,
            'folder': tree.find('folder').text,
            'size': {
                'width': size.find('width').text,
                'height': size.find('height').text,
                'depth': size.find('depth').text,
            },
        }

        obj_list = []
        for obj in tree.findall('object'):
            bndbox = obj.find('bndbox')
            obj_list.append({
                'name': obj.find('name').text,
                'truncated': obj.find('truncated').text,
                'difficult': obj.find('difficult').text,
                'bndbox': {key: bndbox.find(key).text
                           for key in self._CORNER_KEYS},
            })
        anno_dict['object'] = obj_list
        return anno_dict

    def get_mmdet_pipeline_result(self, idx):
        """Return the idx-th annotation as minimum-area rotated boxes.

        Each object's four (x_i, y_i) corners are fitted with
        ``cv2.minAreaRect``. Coordinates must be integer strings.

        Args:
            idx: annotation index, see ``_get_anno_index``.

        Returns:
            dict with
              'gt_bboxes': float64 array of shape (N, 5), one row per object,
                  laid out as (center_x, center_y, width, height, angle) from
                  ``cv2.minAreaRect``; shape (0, 5) when there are no objects.
                  NOTE(review): the angle range/convention depends on the
                  installed OpenCV version — confirm against the consumer.
              'img_shape': (height, width, depth) tuple of ints.
        """
        anno = self._get_anno_index(idx)
        size = anno['size']
        shape = (int(size['height']),
                 int(size['width']),
                 int(size['depth']))

        rows = []
        for obj in anno['object']:
            box = obj['bndbox']
            # minAreaRect works on the convex hull, so corner order is
            # irrelevant. int32 is what the OpenCV binding expects for
            # integer point sets.
            corners = np.array(
                [(int(box[f'x{i}']), int(box[f'y{i}'])) for i in range(1, 5)],
                dtype=np.int32)
            (cx, cy), (w, h), angle = cv2.minAreaRect(corners)
            rows.append([cx, cy, w, h, angle])

        # reshape keeps the (N, 5) layout for files without objects;
        # np.array([]) alone would have shape (0,).
        bboxes = np.array(rows, dtype=np.float64).reshape(-1, 5)
        return dict(gt_bboxes=bboxes, img_shape=shape)